{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "alpaca_lora"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from typing import List\n",
    "\n",
    "import fire\n",
    "import torch\n",
    "import transformers\n",
    "from datasets import load_dataset\n",
    "\n",
    "from peft import (\n",
    "    LoraConfig,\n",
    "    get_peft_model,\n",
    "    get_peft_model_state_dict,\n",
    "    set_peft_model_state_dict,\n",
    ")\n",
    "from transformers import LlamaForCausalLM, LlamaTokenizer\n",
    "\n",
    "from utils.prompter import Prompter\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model: str= \"models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16\"  # the only required argument\n",
    "data_path: str = \"data/aplaca-cleaned-en\"\n",
    "output_dir: str = \"./lora-alpaca\"\n",
    "# training hyperparams\n",
    "batch_size: int = 128 # 每次迭代中用于训练的样本数量。\n",
    "micro_batch_size: int = 4 # 每个 GPU 核心上处理的样本数量。\n",
    "num_epochs: int = 3 # 3 epochs\n",
    "learning_rate: float = 3e-4\n",
    "cutoff_len: int = 256 # max length of input to model\n",
    "val_set_size: int = 2000 # size of validation set\n",
    "# lora hyperparams\n",
    "lora_r: int = 8 # rank of the low rank matrix\n",
    "lora_alpha: int = 16 # alpha of the low rank matrix\n",
    "lora_dropout: float = 0.05 # dropout of the low rank matrix\n",
    "lora_target_modules: List[str] = [\n",
    "    \"q_proj\",\n",
    "    \"v_proj\",\n",
    "] # 要应用 LoRA 方法的模型模块名称。QV模块\n",
    "# llm hyperparams\n",
    "train_on_inputs: bool = True  # if False, masks out inputs in loss是否在损失计算中包含输入。\n",
    "add_eos_token: bool = False # 是否在输出中添加 EOS 标记。\n",
    "group_by_length: bool = False  # faster, but produces an odd training loss curve是否根据序列长度对数据进行分组\n",
    "# wandb params\n",
    "wandb_project: str = \"\" # wandb project name\n",
    "wandb_run_name: str = \"\" # wandb run name\n",
    "wandb_watch: str = \"\"  # options: false | gradients | all 是否让 wandb 观察模型的梯度或全部参数。\n",
    "wandb_log_model: str = \"\"  # options: false | true是否在 wandb 中记录模型的权重。\n",
    "resume_from_checkpoint: str = None  # either training checkpoint or final adapter 表示从哪个检查点恢复训练，可以是训练过程中的检查点或最终的适配器。\n",
    "prompt_template_name: str = \"alpaca\"  # The prompt template to use, will default to alpaca.    # 提示模板的名称，默认为 alpaca。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Alpaca-LoRA model with params:\n",
      "base_model: models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16\n",
      "data_path: data/aplaca-cleaned-en\n",
      "output_dir: ./lora-alpaca\n",
      "batch_size: 128\n",
      "micro_batch_size: 4\n",
      "num_epochs: 3\n",
      "learning_rate: 0.0003\n",
      "cutoff_len: 256\n",
      "val_set_size: 2000\n",
      "lora_r: 8\n",
      "lora_alpha: 16\n",
      "lora_dropout: 0.05\n",
      "lora_target_modules: ['q_proj', 'v_proj']\n",
      "train_on_inputs: True\n",
      "add_eos_token: False\n",
      "group_by_length: False\n",
      "wandb_project: \n",
      "wandb_run_name: \n",
      "wandb_watch: \n",
      "wandb_log_model: \n",
      "resume_from_checkpoint: False\n",
      "prompt template: alpaca\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Only the process with local rank 0 (or a run without LOCAL_RANK set) echoes the configuration.\n",
    "if int(os.environ.get(\"LOCAL_RANK\", 0)) == 0:\n",
    "    print(\n",
    "        \"Training Alpaca-LoRA model with params:\\n\"\n",
    "        + \"\".join(\n",
    "            f\"{label}: {value}\\n\"\n",
    "            for label, value in (\n",
    "                (\"base_model\", base_model),\n",
    "                (\"data_path\", data_path),\n",
    "                (\"output_dir\", output_dir),\n",
    "                (\"batch_size\", batch_size),\n",
    "                (\"micro_batch_size\", micro_batch_size),\n",
    "                (\"num_epochs\", num_epochs),\n",
    "                (\"learning_rate\", learning_rate),\n",
    "                (\"cutoff_len\", cutoff_len),\n",
    "                (\"val_set_size\", val_set_size),\n",
    "                (\"lora_r\", lora_r),\n",
    "                (\"lora_alpha\", lora_alpha),\n",
    "                (\"lora_dropout\", lora_dropout),\n",
    "                (\"lora_target_modules\", lora_target_modules),\n",
    "                (\"train_on_inputs\", train_on_inputs),\n",
    "                (\"add_eos_token\", add_eos_token),\n",
    "                (\"group_by_length\", group_by_length),\n",
    "                (\"wandb_project\", wandb_project),\n",
    "                (\"wandb_run_name\", wandb_run_name),\n",
    "                (\"wandb_watch\", wandb_watch),\n",
    "                (\"wandb_log_model\", wandb_log_model),\n",
    "                (\"resume_from_checkpoint\", resume_from_checkpoint or False),\n",
    "                (\"prompt template\", prompt_template_name),\n",
    "            )\n",
    "        )\n",
    "    )\n",
    "# An empty base_model cannot be loaded, so stop before any heavy work starts.\n",
    "assert base_model, \"Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'\"\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of training processes; WORLD_SIZE is treated as 1 when it is not set.\n",
    "world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n",
    "ddp = world_size != 1  # distributed data parallel is in use when there is more than one process\n",
    "\n",
    "# Micro-batches accumulated before each optimizer update (128 // 4 = 32 with the defaults).\n",
    "gradient_accumulation_steps = batch_size // micro_batch_size\n",
    "\n",
    "device_map = \"auto\"  # let the loader assign devices automatically\n",
    "if ddp:\n",
    "    # Map the whole model (the \"\" key) to the device index given by LOCAL_RANK,\n",
    "    # and share the accumulation steps across the participating processes.\n",
    "    device_map = {\"\": int(os.environ.get(\"LOCAL_RANK\") or 0)}\n",
    "    gradient_accumulation_steps = gradient_accumulation_steps // world_size\n",
    "\n",
    "prompter = Prompter(prompt_template_name)  # builds prompts from the selected template\n",
    "ddp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Weights & Biases logging is enabled when a project name is configured here\n",
    "# or a non-empty WANDB_PROJECT is already present in the environment.\n",
    "use_wandb = bool(wandb_project) or bool(os.environ.get(\"WANDB_PROJECT\"))\n",
    "\n",
    "# Only overwrite environ if wandb param passed\n",
    "if wandb_project:\n",
    "    os.environ[\"WANDB_PROJECT\"] = wandb_project\n",
    "if wandb_watch:\n",
    "    os.environ[\"WANDB_WATCH\"] = wandb_watch\n",
    "if wandb_log_model:\n",
    "    os.environ[\"WANDB_LOG_MODEL\"] = wandb_log_model\n",
    "use_wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.\n",
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:10<00:00,  5.29s/it]\n"
     ]
    }
   ],
   "source": [
    "# Load the base model with 8-bit quantized weights. The quantization settings are passed\n",
    "# through `quantization_config` because handing `load_in_8bit` directly to\n",
    "# `from_pretrained` is deprecated and scheduled for removal.\n",
    "model = LlamaForCausalLM.from_pretrained(\n",
    "    base_model,\n",
    "    quantization_config=transformers.BitsAndBytesConfig(load_in_8bit=True),\n",
    "    torch_dtype=torch.float16,  # 16-bit floating point for the remaining weights\n",
    "    device_map=device_map,\n",
    ")\n",
    "\n",
    "tokenizer = LlamaTokenizer.from_pretrained(base_model)  # tokenizer that ships with the base model\n",
    "\n",
    "# Batches are padded to a common length; use id 0 (unk) as the pad token so it\n",
    "# stays different from the EOS token.\n",
    "tokenizer.pad_token_id = 0\n",
    "tokenizer.padding_side = \"left\"  # Allow batched inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<bound method Module.state_dict of LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x LlamaDecoderLayer(\n",
       "        (self_attn): LlamaSdpaAttention(\n",
       "          (q_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)\n",
       "          (k_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)\n",
       "          (v_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)\n",
       "          (o_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)\n",
       "          (rotary_emb): LlamaRotaryEmbedding()\n",
       "        )\n",
       "        (mlp): LlamaMLP(\n",
       "          (gate_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)\n",
       "          (up_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)\n",
       "          (down_proj): Linear8bitLt(in_features=11008, out_features=4096, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): LlamaRMSNorm()\n",
       "        (post_attention_layernorm): LlamaRMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): LlamaRMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       ")>"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluating the bound method without calling it displays its repr, which includes the model's module tree.\n",
    "model.state_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize(prompt, add_eos_token=True):\n",
    "    \"\"\"Tokenize ``prompt`` and attach training labels.\n",
    "\n",
    "    The prompt is truncated to ``cutoff_len`` tokens and is not padded. When\n",
    "    ``add_eos_token`` is true, an EOS token (with attention mask 1) is appended\n",
    "    unless the sequence already ends with EOS or already holds ``cutoff_len`` tokens.\n",
    "\n",
    "    Returns the tokenizer result with an extra ``labels`` entry that is a copy\n",
    "    of ``input_ids``.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(\n",
    "        prompt,\n",
    "        truncation=True,  # cut inputs that exceed max_length\n",
    "        max_length=cutoff_len,\n",
    "        padding=False,  # padding is left to the data collator\n",
    "        return_tensors=None,  # keep plain Python lists\n",
    "    )\n",
    "    token_ids = encoded[\"input_ids\"]\n",
    "    ends_with_eos = token_ids[-1] == tokenizer.eos_token_id\n",
    "    has_room = len(token_ids) < cutoff_len\n",
    "\n",
    "    if add_eos_token and has_room and not ends_with_eos:\n",
    "        token_ids.append(tokenizer.eos_token_id)\n",
    "        encoded[\"attention_mask\"].append(1)  # the appended EOS is a real (attended) token\n",
    "\n",
    "    encoded[\"labels\"] = token_ids.copy()\n",
    "    return encoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_and_tokenize_prompt(data_point):\n",
    "    \"\"\"Turn one dataset record into a tokenized training example.\n",
    "\n",
    "    ``data_point`` must provide the keys ``instruction``, ``input`` and ``output``.\n",
    "    The full prompt (including the output) is tokenized with ``tokenize``. When\n",
    "    ``train_on_inputs`` is false, the leading labels that correspond to the\n",
    "    prompt without the output are replaced by -100 so they are masked out of the loss.\n",
    "    \"\"\"\n",
    "    instruction, context = data_point[\"instruction\"], data_point[\"input\"]\n",
    "    tokenized_full_prompt = tokenize(\n",
    "        prompter.generate_prompt(instruction, context, data_point[\"output\"])\n",
    "    )\n",
    "    if train_on_inputs:\n",
    "        return tokenized_full_prompt\n",
    "\n",
    "    # Tokenize the prompt without the output to find how many leading tokens to mask.\n",
    "    prompt_only = tokenize(\n",
    "        prompter.generate_prompt(instruction, context),\n",
    "        add_eos_token=add_eos_token,\n",
    "    )\n",
    "    masked_len = len(prompt_only[\"input_ids\"])\n",
    "    if add_eos_token:\n",
    "        # Presumably discounts the EOS appended to the prompt-only tokenization.\n",
    "        masked_len -= 1\n",
    "\n",
    "    full_labels = tokenized_full_prompt[\"labels\"]\n",
    "    tokenized_full_prompt[\"labels\"] = [-100] * masked_len + full_labels[masked_len:]\n",
    "    return tokenized_full_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [1, 13866, 338, 385, 15278, 393, 16612, 263, 3414, 29889, 14350, 263, 2933, 393, 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, 29937, 2799, 4080, 29901, 13, 30982, 31695, 31863, 31577, 30210, 30457, 30502, 31302, 30858, 30267, 13, 13, 2277, 29937, 13291, 29901, 13, 30651, 30557, 30392, 30982, 31695, 31863, 31577, 30210, 30457, 30502, 31302, 30858, 30383, 13, 13, 29896, 29889, 29871, 30982, 31695, 31687, 30988, 31704, 30846, 30267, 31951, 30408, 232, 132, 157, 236, 131, 133, 30948, 30210, 31687, 30988, 31894, 30846, 30214, 30847, 233, 152, 166, 233, 176, 168, 30330, 235, 186, 148, 233, 176, 168, 31391, 233, 187, 187, 233, 182, 182, 30214, 30815, 231, 194, 134, 31174, 30869, 235, 164, 131, 31624, 31863, 31577, 30214, 232, 165, 161, 232, 191, 189, 235, 133, 143, 235, 133, 140, 31074, 31180, 30214, 31666, 30417, 31931, 30909, 232, 138, 146, 31022, 30988, 30908, 30267, 13, 13, 29906, 29889, 29871, 232, 160, 138, 235, 164, 164, 236, 168, 177, 31855, 30267, 31951, 30408, 31855, 30406, 30374, 236, 181, 159, 30210, 235, 151, 175, 31854, 30330, 30716, 30801, 30330, 30753, 31112, 30834, 30503, 235, 135, 133, 235, 133, 173, 232, 147, 174, 31180, 231, 192, 145, 30210, 235, 158, 142, 30868, 235, 183, 171, 31855, 30834, 30214, 236, 132, 194, 232, 136, 144, 30528, 234, 182, 153, 30330, 30528, 235, 135, 133, 235, 133, 173, 30503, 30666, 31041, 31855, 31399, 30214, 30651, 30982, 31695, 31863, 31577, 30210, 236, 168, 177, 31855, 231, 188, 163, 233, 134, 178, 30267, 13, 13, 29941, 29889, 29871, 234, 160], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [1, 13866, 338, 385, 15278, 393, 16612, 263, 3414, 29889, 14350, 263, 2933, 393, 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, 29937, 2799, 4080, 29901, 13, 30982, 31695, 31863, 31577, 30210, 30457, 30502, 31302, 30858, 30267, 13, 13, 2277, 29937, 13291, 29901, 13, 30651, 30557, 30392, 30982, 31695, 31863, 31577, 30210, 30457, 30502, 31302, 30858, 30383, 13, 13, 29896, 29889, 29871, 30982, 31695, 31687, 30988, 31704, 30846, 30267, 31951, 30408, 232, 132, 157, 236, 131, 133, 30948, 30210, 31687, 30988, 31894, 30846, 30214, 30847, 233, 152, 166, 233, 176, 168, 30330, 235, 186, 148, 233, 176, 168, 31391, 233, 187, 187, 233, 182, 182, 30214, 30815, 231, 194, 134, 31174, 30869, 235, 164, 131, 31624, 31863, 31577, 30214, 232, 165, 161, 232, 191, 189, 235, 133, 143, 235, 133, 140, 31074, 31180, 30214, 31666, 30417, 31931, 30909, 232, 138, 146, 31022, 30988, 30908, 30267, 13, 13, 29906, 29889, 29871, 232, 160, 138, 235, 164, 164, 236, 168, 177, 31855, 30267, 31951, 30408, 31855, 30406, 30374, 236, 181, 159, 30210, 235, 151, 175, 31854, 30330, 30716, 30801, 30330, 30753, 31112, 30834, 30503, 235, 135, 133, 235, 133, 173, 232, 147, 174, 31180, 231, 192, 145, 30210, 235, 158, 142, 30868, 235, 183, 171, 31855, 30834, 30214, 236, 132, 194, 232, 136, 144, 30528, 234, 182, 153, 30330, 30528, 235, 135, 133, 235, 133, 173, 30503, 30666, 31041, 31855, 31399, 30214, 30651, 30982, 31695, 31863, 31577, 30210, 236, 168, 177, 31855, 231, 188, 163, 233, 134, 178, 30267, 13, 13, 29941, 29889, 29871, 234, 160]}"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: run the prompt/tokenization pipeline on one hand-written record with an empty input field.\n",
    "data_point =  {\n",
    "    \"instruction\": \"保持健康的三个提示。\",\n",
    "    \"input\": \"\",\n",
    "    \"output\": \"以下是保持健康的三个提示：\\n\\n1. 保持身体活动。每天做适当的身体运动，如散步、跑步或游泳，能促进心血管健康，增强肌肉力量，并有助于减少体重。\\n\\n2. 均衡饮食。每天食用新鲜的蔬菜、水果、全谷物和脂肪含量低的蛋白质食物，避免高糖、高脂肪和加工食品，以保持健康的饮食习惯。\\n\\n3. 睡眠充足。睡眠对人体健康至关重要，成年人每天应保证 7-8 小时的睡眠。良好的睡眠有助于减轻压力，促进身体恢复，并提高注意力和记忆力。\"\n",
    "  }\n",
    "generate_and_tokenize_prompt(data_point=data_point)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.0622\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# LoRA adapter settings built from the hyperparameters configured above.\n",
    "config = LoraConfig(\n",
    "    task_type=\"CAUSAL_LM\",  # causal language modelling (text generation)\n",
    "    target_modules=lora_target_modules,  # model modules that receive LoRA, e.g. attention projections\n",
    "    r=lora_r,  # rank of the low-rank matrices\n",
    "    lora_alpha=lora_alpha,  # scaling factor for the low-rank matrices\n",
    "    lora_dropout=lora_dropout,  # dropout rate\n",
    "    bias=\"none\",  # no bias is added\n",
    ")\n",
    "model = get_peft_model(model, config)  # combine the base model with the LoRA configuration\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['instruction', 'input', 'output'],\n",
       "        num_rows: 51760\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A .json/.jsonl path is read with the generic \"json\" loader; any other value\n",
    "# is handed to load_dataset unchanged.\n",
    "if data_path.endswith((\".json\", \".jsonl\")):\n",
    "    data = load_dataset(\"json\", data_files=data_path)\n",
    "else:\n",
    "    data = load_dataset(data_path)\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 49760/49760 [01:12<00:00, 682.25 examples/s]\n",
      "Map: 100%|██████████| 2000/2000 [00:02<00:00, 688.79 examples/s]\n"
     ]
    }
   ],
   "source": [
    "if resume_from_checkpoint:\n",
    "    # Look for a full checkpoint first, then fall back to a LoRA-only adapter file.\n",
    "    checkpoint_name = os.path.join(resume_from_checkpoint, \"pytorch_model.bin\")\n",
    "    if not os.path.exists(checkpoint_name):\n",
    "        # Adapter-only weights: the LoRA config above has to fit them.\n",
    "        checkpoint_name = os.path.join(resume_from_checkpoint, \"adapter_model.bin\")\n",
    "        resume_from_checkpoint = False  # So the trainer won't try loading its state\n",
    "    # The two files above have a different name depending on how they were saved, but are actually the same.\n",
    "    if os.path.exists(checkpoint_name):\n",
    "        print(f\"Restarting from {checkpoint_name}\")\n",
    "        adapters_weights = torch.load(checkpoint_name)\n",
    "        set_peft_model_state_dict(model, adapters_weights)\n",
    "    else:\n",
    "        print(f\"Checkpoint {checkpoint_name} not found\")\n",
    "\n",
    "if val_set_size > 0:\n",
    "    # Hold out val_set_size examples from the training split (shuffled with a fixed seed).\n",
    "    train_val = data[\"train\"].train_test_split(\n",
    "        test_size=val_set_size, shuffle=True, seed=42\n",
    "    )\n",
    "    # Build the prompt and tokenize every record of both splits.\n",
    "    train_data = train_val[\"train\"].shuffle().map(generate_and_tokenize_prompt)\n",
    "    val_data = train_val[\"test\"].shuffle().map(generate_and_tokenize_prompt)\n",
    "else:\n",
    "    train_data = data[\"train\"].shuffle().map(generate_and_tokenize_prompt)\n",
    "    val_data = None\n",
    "\n",
    "if not ddp and torch.cuda.device_count() > 1:\n",
    "    # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available\n",
    "    model.is_parallelizable = True\n",
    "    model.model_parallel = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'instruction': \"Write a JavaScript code to display an alert message when the 'Submit' button is pressed.\", 'input': '', 'output': 'Here is the JavaScript code you can use to display an alert message when the \\'Submit\\' button is pressed:\\n\\n```javascript\\ndocument.getElementById(\\'submit-button\\').addEventListener(\\'click\\', function() {\\n  alert(\\'Your form has been submitted!\\');\\n});\\n```\\n\\nThis code selects the \"Submit\" button by its `id` attribute using `document.getElementById(\\'submit-button\\')`. It then adds an event listener to the button that listens for click events, using the `addEventListener` method. When the button is clicked, the callback function is executed, which displays an alert message using the `alert` function. In this case, the alert message says \"Your form has been submitted!\".\\n\\nMake sure to assign the proper `id` attribute (in this example, `id=\"submit-button\"`) to the submit button in HTML, so that the element can be properly selected by the script above.'}\n",
      "{'instruction': 'Compare two different albums.', 'input': 'Beyonce, Lemonade and Rihanna, Anti', 'output': 'Lemonade by Beyonce and Anti by Rihanna are two critically-acclaimed albums released by two powerhouse female artists in the music industry. \\n\\nLemonade, released in 2016, was Beyonce\\'s sixth studio album. The visual album tells a story about infidelity, loss, and forgiveness. It\\'s a powerful album that mixes genres, ranging from hip-hop and R&B to rock, soul and even country. The album features collaborations with artists such as Jack White, Kendrick Lamar, and The Weeknd. Lemonade received widespread critical and commercial success, with many applauding the album\\'s powerful message as well as its originality and artistry.\\n\\nOn the other hand, Anti by Rihanna, also released in 2016, is an album that showcases the singer\\'s evolution as an artist. The album features a mix of genres, including pop, R&B, soul, and dancehall, and includes a variety of musical styles and moods. Anti is an album that focuses on themes of love, self-discovery, and self-assurance. Standout tracks from the album include \"Work,\" \"Kiss It Better,\" and \"Love on the Brain\". Like Lemonade, Anti was a commercial and critical success, applauded for its originality and artistic expression.\\n\\nIn conclusion, both albums are unique and powerful in their own right. While Lemonade\\'s central theme centers around infidelity, Anti is more about self-discovery and love. 
Both albums showcase the artists\\' creative vision and talent, pushing the boundary of music and artistry.', 'input_ids': [1, 13866, 338, 385, 15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385, 1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393, 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, 29937, 2799, 4080, 29901, 13, 6843, 598, 1023, 1422, 20618, 29889, 13, 13, 2277, 29937, 10567, 29901, 13, 29933, 1032, 10646, 29892, 8836, 265, 1943, 322, 390, 4861, 9713, 29892, 18473, 13, 13, 2277, 29937, 13291, 29901, 13, 29931, 9857, 1943, 491, 18502, 10646, 322, 18473, 491, 390, 4861, 9713, 526, 1023, 3994, 1711, 29899, 5753, 13190, 20618, 5492, 491, 1023, 3081, 8697, 12944, 17906, 297, 278, 4696, 13661, 29889, 29871, 13, 13, 29931, 9857, 1943, 29892, 5492, 297, 29871, 29906, 29900, 29896, 29953, 29892, 471, 18502, 10646, 29915, 29879, 25963, 8693, 3769, 29889, 450, 7604, 3769, 10603, 263, 5828, 1048, 3041, 10652, 537, 29892, 6410, 29892, 322, 18879, 20193, 29889, 739, 29915, 29879, 263, 13988, 3769, 393, 6837, 267, 2531, 690, 29892, 364, 9776, 515, 21464, 29899, 29882, 459, 322, 390, 29987, 29933, 304, 7679, 29892, 10752, 322, 1584, 4234, 29889, 450, 3769, 5680, 11465, 800, 411, 17906, 1316, 408, 5457, 8037, 29892, 476, 355, 9131, 365, 8715, 29892, 322, 450, 15511, 299, 29889, 8836, 265, 1943, 4520, 281, 2247, 29886, 949, 12187, 322, 12128, 2551, 29892, 411, 1784, 623, 433, 566, 292, 278, 3769, 29915, 29879, 13988, 2643, 408, 1532, 408, 967, 2441, 537, 322, 1616, 6020, 29889, 13, 13, 2951, 278, 916, 1361, 29892, 18473, 491, 390, 4861, 9713, 29892, 884, 5492, 297, 29871, 29906, 29900, 29896, 29953, 29892], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [1, 13866, 338, 385, 15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385, 1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393, 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, 29937, 2799, 4080, 29901, 13, 6843, 598, 1023, 1422, 20618, 29889, 13, 13, 2277, 29937, 10567, 29901, 13, 29933, 1032, 10646, 29892, 8836, 265, 1943, 322, 390, 4861, 9713, 29892, 18473, 13, 13, 2277, 29937, 13291, 29901, 13, 29931, 9857, 1943, 491, 18502, 10646, 322, 18473, 491, 390, 4861, 9713, 526, 1023, 3994, 1711, 29899, 5753, 13190, 20618, 5492, 491, 1023, 3081, 8697, 12944, 17906, 297, 278, 4696, 13661, 29889, 29871, 13, 13, 29931, 9857, 1943, 29892, 5492, 297, 29871, 29906, 29900, 29896, 29953, 29892, 471, 18502, 10646, 29915, 29879, 25963, 8693, 3769, 29889, 450, 7604, 3769, 10603, 263, 5828, 1048, 3041, 10652, 537, 29892, 6410, 29892, 322, 18879, 20193, 29889, 739, 29915, 29879, 263, 13988, 3769, 393, 6837, 267, 2531, 690, 29892, 364, 9776, 515, 21464, 29899, 29882, 459, 322, 390, 29987, 29933, 304, 7679, 29892, 10752, 322, 1584, 4234, 29889, 450, 3769, 5680, 11465, 800, 411, 17906, 1316, 408, 5457, 8037, 29892, 476, 355, 9131, 365, 8715, 29892, 322, 450, 15511, 299, 29889, 8836, 265, 1943, 4520, 281, 2247, 29886, 949, 12187, 322, 12128, 2551, 29892, 411, 1784, 623, 433, 566, 292, 278, 3769, 29915, 29879, 13988, 2643, 408, 1532, 408, 967, 2441, 537, 322, 1616, 6020, 29889, 13, 13, 2951, 278, 916, 1361, 29892, 18473, 491, 390, 4861, 9713, 29892, 884, 5492, 297, 29871, 29906, 29900, 29896, 29953, 29892]}\n",
      "Dataset({\n",
      "    features: ['instruction', 'input', 'output', 'input_ids', 'attention_mask', 'labels'],\n",
      "    num_rows: 2000\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "# Show one raw training record, one tokenized training record, and the validation dataset summary.\n",
    "print(train_val[\"train\"][0], train_data[0], val_data, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/anaconda3/envs/ChatTTS/lib/python3.10/site-packages/transformers/training_args.py:1474: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<bound method Module.state_dict of OptimizedModule(\n",
       "  (_orig_mod): PeftModelForCausalLM(\n",
       "    (base_model): LoraModel(\n",
       "      (model): LlamaForCausalLM(\n",
       "        (model): LlamaModel(\n",
       "          (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n",
       "          (layers): ModuleList(\n",
       "            (0-31): 32 x LlamaDecoderLayer(\n",
       "              (self_attn): LlamaSdpaAttention(\n",
       "                (q_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=4096, out_features=4096, bias=False)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=4096, out_features=8, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=8, out_features=4096, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (k_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)\n",
       "                (v_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=4096, out_features=4096, bias=False)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=4096, out_features=8, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=8, out_features=4096, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (o_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)\n",
       "                (rotary_emb): LlamaRotaryEmbedding()\n",
       "              )\n",
       "              (mlp): LlamaMLP(\n",
       "                (gate_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)\n",
       "                (up_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)\n",
       "                (down_proj): Linear8bitLt(in_features=11008, out_features=4096, bias=False)\n",
       "                (act_fn): SiLU()\n",
       "              )\n",
       "              (input_layernorm): LlamaRMSNorm()\n",
       "              (post_attention_layernorm): LlamaRMSNorm()\n",
       "            )\n",
       "          )\n",
       "          (norm): LlamaRMSNorm()\n",
       "        )\n",
       "        (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       ")>"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Configure the trainer. Evaluation and checkpointing both run every 200 steps;\n",
    "# evaluation is switched off when no validation split was held out.\n",
    "trainer = transformers.Trainer(\n",
    "    model=model,\n",
    "    train_dataset=train_data,\n",
    "    eval_dataset=val_data,\n",
    "    args=transformers.TrainingArguments(\n",
    "        per_device_train_batch_size=micro_batch_size,\n",
    "        gradient_accumulation_steps=gradient_accumulation_steps,\n",
    "        warmup_steps=100,\n",
    "        num_train_epochs=num_epochs,\n",
    "        learning_rate=learning_rate,\n",
    "        fp16=True,\n",
    "        logging_steps=10,\n",
    "        optim=\"adamw_torch\",\n",
    "        # NOTE: recent transformers releases warn that `evaluation_strategy` is deprecated\n",
    "        # in favour of `eval_strategy`; the old name is kept for older installations.\n",
    "        evaluation_strategy=\"steps\" if val_set_size > 0 else \"no\",\n",
    "        save_strategy=\"steps\",\n",
    "        eval_steps=200 if val_set_size > 0 else None,\n",
    "        save_steps=200,\n",
    "        output_dir=output_dir,\n",
    "        save_total_limit=3,\n",
    "        load_best_model_at_end=val_set_size > 0,\n",
    "        ddp_find_unused_parameters=False if ddp else None,\n",
    "        group_by_length=group_by_length,\n",
    "        # report_to=None falls back to the library default (every installed integration)\n",
    "        # in transformers 4.x, so \"none\" is passed explicitly to turn reporting off.\n",
    "        report_to=\"wandb\" if use_wandb else \"none\",\n",
    "        run_name=wandb_run_name if use_wandb else None,\n",
    "    ),\n",
    "    # Pads each batch (to a multiple of 8) and returns PyTorch tensors.\n",
    "    data_collator=transformers.DataCollatorForSeq2Seq(\n",
    "        tokenizer, pad_to_multiple_of=8, return_tensors=\"pt\", padding=True\n",
    "    ),\n",
    ")\n",
    "model.config.use_cache = False  # disable the model's cache while training\n",
    "\n",
    "# The former `model.state_dict` override (wrapping get_peft_model_state_dict) is intentionally\n",
    "# gone: PEFT's save_pretrained already extracts the adapter weights itself, and filtering the\n",
    "# state dict twice makes current PEFT releases write empty adapter files.\n",
    "\n",
    "# Compare the numeric major version; a plain string comparison would mis-order e.g. \"10.0\" and \"2\".\n",
    "if int(torch.__version__.split(\".\")[0]) >= 2 and sys.platform != \"win32\":\n",
    "    model = torch.compile(model)  # compile the model with torch.compile\n",
    "model.state_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run training; resume_from_checkpoint is either a checkpoint path or a falsy value.\n",
    "trainer.train(resume_from_checkpoint=resume_from_checkpoint)\n",
    "\n",
    "# Save the final model into output_dir, creating the directory first if it is missing.\n",
    "if not os.path.exists(output_dir):\n",
    "    os.makedirs(output_dir)\n",
    "model.save_pretrained(output_dir)\n",
    "\n",
    "print(\"\\n If there's a warning about missing keys above, please disregard :)\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ChatTTS",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
